{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from sklearn.ensemble import GradientBoostingClassifier as GBC\n",
    "from sklearn.ensemble import RandomForestClassifier as RFC\n",
    "from sklearn.tree import DecisionTreeClassifier as DTC\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.datasets import load_wine\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>alcohol</th>\n",
       "      <th>malic_acid</th>\n",
       "      <th>ash</th>\n",
       "      <th>alcalinity_of_ash</th>\n",
       "      <th>magnesium</th>\n",
       "      <th>total_phenols</th>\n",
       "      <th>flavanoids</th>\n",
       "      <th>nonflavanoid_phenols</th>\n",
       "      <th>proanthocyanins</th>\n",
       "      <th>color_intensity</th>\n",
       "      <th>hue</th>\n",
       "      <th>od280/od315_of_diluted_wines</th>\n",
       "      <th>proline</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>14.23</td>\n",
       "      <td>1.71</td>\n",
       "      <td>2.43</td>\n",
       "      <td>15.6</td>\n",
       "      <td>127.0</td>\n",
       "      <td>2.80</td>\n",
       "      <td>3.06</td>\n",
       "      <td>0.28</td>\n",
       "      <td>2.29</td>\n",
       "      <td>5.64</td>\n",
       "      <td>1.04</td>\n",
       "      <td>3.92</td>\n",
       "      <td>1065.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>13.20</td>\n",
       "      <td>1.78</td>\n",
       "      <td>2.14</td>\n",
       "      <td>11.2</td>\n",
       "      <td>100.0</td>\n",
       "      <td>2.65</td>\n",
       "      <td>2.76</td>\n",
       "      <td>0.26</td>\n",
       "      <td>1.28</td>\n",
       "      <td>4.38</td>\n",
       "      <td>1.05</td>\n",
       "      <td>3.40</td>\n",
       "      <td>1050.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>13.16</td>\n",
       "      <td>2.36</td>\n",
       "      <td>2.67</td>\n",
       "      <td>18.6</td>\n",
       "      <td>101.0</td>\n",
       "      <td>2.80</td>\n",
       "      <td>3.24</td>\n",
       "      <td>0.30</td>\n",
       "      <td>2.81</td>\n",
       "      <td>5.68</td>\n",
       "      <td>1.03</td>\n",
       "      <td>3.17</td>\n",
       "      <td>1185.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>14.37</td>\n",
       "      <td>1.95</td>\n",
       "      <td>2.50</td>\n",
       "      <td>16.8</td>\n",
       "      <td>113.0</td>\n",
       "      <td>3.85</td>\n",
       "      <td>3.49</td>\n",
       "      <td>0.24</td>\n",
       "      <td>2.18</td>\n",
       "      <td>7.80</td>\n",
       "      <td>0.86</td>\n",
       "      <td>3.45</td>\n",
       "      <td>1480.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>13.24</td>\n",
       "      <td>2.59</td>\n",
       "      <td>2.87</td>\n",
       "      <td>21.0</td>\n",
       "      <td>118.0</td>\n",
       "      <td>2.80</td>\n",
       "      <td>2.69</td>\n",
       "      <td>0.39</td>\n",
       "      <td>1.82</td>\n",
       "      <td>4.32</td>\n",
       "      <td>1.04</td>\n",
       "      <td>2.93</td>\n",
       "      <td>735.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   alcohol  malic_acid   ash  ...   hue  od280/od315_of_diluted_wines  proline\n",
       "0    14.23        1.71  2.43  ...  1.04                          3.92   1065.0\n",
       "1    13.20        1.78  2.14  ...  1.05                          3.40   1050.0\n",
       "2    13.16        2.36  2.67  ...  1.03                          3.17   1185.0\n",
       "3    14.37        1.95  2.50  ...  0.86                          3.45   1480.0\n",
       "4    13.24        2.59  2.87  ...  1.04                          2.93    735.0\n",
       "\n",
       "[5 rows x 13 columns]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the wine dataset and wrap the feature matrix in a labelled DataFrame.\n",
    "data = load_wine()\n",
    "y = data.target\n",
    "x = pd.DataFrame(data.data, columns=data.feature_names)\n",
    "\n",
    "# Preview the first rows of the feature table.\n",
    "x.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hold out 30% of the wine samples for testing; the fixed seed keeps the split repeatable.\n",
    "x_train, x_test, y_train, y_test = train_test_split(\n",
    "    x, y, test_size=0.3, random_state=42\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "训练集准确率: 1.0\n",
      "测试集准确率: 0.9074074074074074\n"
     ]
    }
   ],
   "source": [
    "# Fit a gradient-boosting classifier on the wine training split.\n",
    "# random_state is pinned (as for the other estimators in this notebook) so the\n",
    "# reported accuracies are repeatable across runs.\n",
    "gbc = GBC(random_state=0)\n",
    "gbc.fit(x_train, y_train)\n",
    "\n",
    "# score() returns mean accuracy for classifiers.\n",
    "train_score = gbc.score(x_train, y_train)\n",
    "test_score = gbc.score(x_test, y_test)\n",
    "\n",
    "print(\"训练集准确率:\", train_score)  # training-set accuracy\n",
    "print(\"测试集准确率:\", test_score)  # test-set accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 0, 1, 0, 1, 0, 1, 2, 1, 0])"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Predicted wine classes for the first ten held-out samples.\n",
    "gbc.predict(x_test.head(10))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import mean_squared_error\n",
    "from sklearn.ensemble import GradientBoostingRegressor as GBR\n",
    "from sklearn.datasets import fetch_california_housing\n",
    "from sklearn.tree import DecisionTreeRegressor as DTR\n",
    "from sklearn.ensemble import RandomForestRegressor as RFR\n",
    "from sklearn.model_selection import cross_val_score\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   MedInc  HouseAge  AveRooms  ...  AveOccup  Latitude  Longitude\n",
      "0  8.3252      41.0  6.984127  ...  2.555556     37.88    -122.23\n",
      "1  8.3014      21.0  6.238137  ...  2.109842     37.86    -122.22\n",
      "2  7.2574      52.0  8.288136  ...  2.802260     37.85    -122.24\n",
      "3  5.6431      52.0  5.817352  ...  2.547945     37.85    -122.25\n",
      "4  3.8462      52.0  6.281853  ...  2.181467     37.85    -122.25\n",
      "\n",
      "[5 rows x 8 columns]\n"
     ]
    }
   ],
   "source": [
    "# Fetch the California housing data as pandas objects: features in x, target in y.\n",
    "x, y = fetch_california_housing(return_X_y=True, as_frame=True)\n",
    "print(x.head())\n",
    "\n",
    "# Re-split for the regression task; this rebinds the train/test names used above.\n",
    "x_train, x_test, y_train, y_test = train_test_split(\n",
    "    x, y, test_size=0.3, random_state=42\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.7803117736943137"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train a gradient-boosting regressor (fit() returns the estimator itself).\n",
    "gbr = GBR(random_state=0).fit(x_train, y_train)\n",
    "\n",
    "# R^2 on the held-out housing split.\n",
    "gbr.score(x_test, y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "# Labels and estimators are kept in matching order so they can be zipped together.\n",
    "modelname = [\n",
    "    'DecisionTree',\n",
    "    'GBDT',\n",
    "    'RF-D',\n",
    "]\n",
    "models = [\n",
    "    DTR(random_state=0),\n",
    "    GBR(random_state=0),\n",
    "    RFR(random_state=0, max_depth=3),  # only the forest has a depth cap\n",
    "]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DecisionTree\n",
      "\t MES:-0.8176\n",
      "\t Time:1.0814\n",
      "\n",
      "\n",
      "GBDT\n",
      "\t MES:-0.4124\n",
      "\t Time:18.7732\n",
      "\n",
      "\n",
      "RF-D\n",
      "\t MES:-0.6385\n",
      "\t Time:12.3985\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Compare the regressors with 5-fold cross-validation on the full housing data.\n",
    "for name, model in zip(modelname, models):\n",
    "    start = time.perf_counter()  # monotonic clock, suited to measuring intervals\n",
    "    scores = cross_val_score(model, x, y, cv=5, scoring='neg_mean_squared_error')\n",
    "    elapsed = time.perf_counter() - start\n",
    "    print(name)\n",
    "    # The scorer returns negated MSE (higher is better), so flip the sign to\n",
    "    # report the actual mean squared error.\n",
    "    print(\"\\t MSE:{:.4f}\".format(-scores.mean()))\n",
    "    print(\"\\t Time:{:.4f}\".format(elapsed))\n",
    "    print('\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
